# Copyright 2018 Google. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train NMT with low level API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import threading
import time

import tensorflow as tf

from tensorflow.contrib import tpu
from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.framework import graph_io
from mlperf_compliance import mlperf_log
from utils import iterator_utils
from utils import vocab_utils

from lottery import lottery

_INITIAL_LOSS = 1e7


def wrap_computation_in_while_loop(op_fn, n, host_name):
  """Repeats the ops built by `op_fn` inside a tf.while_loop.

  Args:
    op_fn: Zero-argument callable returning an op or a list of ops. It is
      invoked from within the loop body.
    n: Upper bound for the loop counter, which starts at 0.
    host_name: Host prefix; the loop is placed on that host's CPU device.

  Returns:
    Whatever tf.while_loop returns for the single counter loop variable.
  """

  def keep_going(step):
    return tf.less(step, n)

  def loop_body(step):
    generated = op_fn()
    deps = generated if isinstance(generated, list) else [generated]
    # The counter only advances once this iteration's ops have executed.
    with tf.control_dependencies(deps):
      return step + 1

  host_cpu = device_for_host(host_name)
  with tf.device(host_cpu):
    return tf.while_loop(
        keep_going, loop_body, [tf.constant(0)], parallel_iterations=1)


def get_resolver(hparams):
  """Returns a TPUClusterResolver built from hparams, or None.

  `hparams.master` takes precedence over `hparams.tpu_name`. When both are
  empty or None, no resolver is created and None is returned.
  """
  target = hparams.master or hparams.tpu_name
  if not target:
    return None
  return tf.contrib.cluster_resolver.TPUClusterResolver(target)


def get_host(resolver, host_id=0):
  """Returns the device-name prefix for host `host_id`.

  With no resolver, the fixed local prefix "/replica:0/task:0" is returned
  regardless of `host_id`. Otherwise the prefix is built from the resolver's
  job name, falling back to "tpu_worker" when the resolver reports none.
  """
  if resolver is not None:
    job = resolver.get_job_name()
    if not job:
      job = "tpu_worker"
    return "/job:%s/task:%d" % (job, host_id)
  return "/replica:0/task:0"


def device_for_host(host_name):
  """Returns the name of CPU device 0 on the given host prefix."""
  cpu_suffix = "/device:CPU:0"
  return host_name + cpu_suffix


def device_for_tpu_core(host_name, core=0):
  """Returns the replicated-TPU-core device name for `core` on the host."""
  core_suffix = "/device:TPU_REPLICATED_CORE:%d" % core
  return host_name + core_suffix


class TrainLowLevelRunner(object):
  """Run Train via direct session.run calls."""

  def __init__(self, iterations, hparams, per_host_v1=False):
    """Creates the graph, opens a session and initializes the TPU system.

    Args:
      iterations: Number of steps executed by each on-device training loop
        (and, in the default infeed mode, by each host-side infeed loop).
      hparams: Hyperparameters object. This constructor reads `master` and
        `tpu_name` and passes `hparams.values()` to the lottery hook factory.
      per_host_v1: If True, use the "v1" per-host infeed path, which enqueues
        the value returned by `input_fn` directly instead of pulling from an
        iterator inside a while loop.
    """
    tf.logging.info("TrainLowLevelRunner: constructor")

    self.feature_structure = {}
    self.loss = None
    self.infeed_queue = []
    self.enqueue_ops = []
    self.dataset_initializer = []
    self.is_local = ((hparams.master == "") and (hparams.tpu_name is None))
    self.per_host_v1 = per_host_v1
    self.iterations = iterations
    self.sess = None
    self.graph = tf.Graph()
    self.hparams = hparams
    with self.graph.as_default():
      self.tpu_init = [tpu.initialize_system()]
      self.tpu_shutdown = tpu.shutdown_system()

    self.resolver = get_resolver(hparams)
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    isolate_session_state=True)
    # get_resolver() returns None whenever both `master` and `tpu_name` are
    # falsy, so an empty-string tpu_name must not reach the resolver branch
    # below (it would fail on `None.cluster_spec()`).
    if self.hparams.tpu_name is None or self.resolver is None:
      master = self.hparams.master
    else:
      cluster_spec = self.resolver.cluster_spec()
      tf.logging.info(cluster_spec)
      if cluster_spec:
        session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
      master = self.resolver.get_master()
    self.sess = tf.Session(master, graph=self.graph, config=session_config)
    self.sess.run(self.tpu_init)

    self.hooks = lottery.hooks_from_flags(hparams.values())

  def initialize(self, input_fn, params):
    """Initialize all the things required for training.

    Builds the infeed enqueue ops for every host, then runs the table
    initializers and the dataset iterator initializers in the session.

    Args:
      input_fn: Callable invoked as `input_fn(params)` once per host. In
        per_host_v1 mode its return value is enqueued directly and it must
        expose an `_iterator` attribute; otherwise it must return a dataset.
      params: Dict forwarded to `input_fn`. It is mutated in place:
        "dataset_num_shards" and "dataset_index" are set for each host.
    """
    tf.logging.info("TrainLowLevelRunner: initialize method")

    num_hosts = self.hparams.num_shards // self.hparams.num_shards_per_host

    def get_enqueue_ops_fn(host_id):
      """Generate the enqueue ops graph function."""

      params["dataset_num_shards"] = num_hosts
      params["dataset_index"] = host_id

      output = input_fn(params)
      device = device_for_host(get_host(self.resolver, host_id))
      with tf.device(device):
        if self.per_host_v1:
          iterator = input_fn._iterator
        else:
          # output is a dataset
          iterator = output.make_initializable_iterator()
        self.dataset_initializer.append(iterator.initializer)

        def enqueue_ops_fn_v1():
          """Enqueue ops function for one host (v1: split one host batch)."""
          features = output
          self.feature_structure["features"] = features
          self.feature_structure["labels"] = {}
          flattened_inputs = data_nest.flatten(self.feature_structure)
          infeed = tpu.InfeedQueue(
              tuple_types=[t.dtype for t in flattened_inputs],
              tuple_shapes=[t.shape for t in flattened_inputs],
              shard_dimensions=None)

          infeed.set_number_of_shards(self.hparams.num_shards_per_host)
          self.infeed_queue.append(infeed)

          def tpu_ordinal_fn(shard_index_in_host):
            return shard_index_in_host % self.hparams.num_shards_per_host

          per_host_enqueue_ops = (
              infeed.split_inputs_and_generate_enqueue_ops(
                  flattened_inputs,
                  placement_function=lambda x: device,
                  tpu_ordinal_function=tpu_ordinal_fn))
          return per_host_enqueue_ops

        def enqueue_ops_fn_v2():
          """Enqueue ops function for one host (v2: one get_next per shard)."""
          per_host_sharded_inputs = []
          control_deps = []
          for _ in range(self.hparams.num_shards_per_host):
            # Each get_next() depends on the tensors of all previous ones so
            # the per-shard reads are issued in a fixed order.
            with tf.control_dependencies(control_deps):
              features = iterator.get_next()
            self.feature_structure["features"] = features
            self.feature_structure["labels"] = {}
            flattened_inputs = data_nest.flatten(self.feature_structure)
            control_deps.extend(flattened_inputs)
            per_host_sharded_inputs.append(flattened_inputs)

          infeed = tpu.InfeedQueue(
              number_of_tuple_elements=len(per_host_sharded_inputs[0]))
          self.infeed_queue.append(infeed)

          def tpu_ordinal_fn(shard_index_in_host):
            return shard_index_in_host % self.hparams.num_shards_per_host

          return infeed.generate_enqueue_ops(
              per_host_sharded_inputs,
              tpu_ordinal_function=tpu_ordinal_fn)

      if self.per_host_v1:
        return enqueue_ops_fn_v1
      else:
        return enqueue_ops_fn_v2

    with self.graph.as_default():
      for i in range(num_hosts):
        if self.per_host_v1:
          # v1 ops are run once per step from the host, so build them directly.
          self.enqueue_ops.append(get_enqueue_ops_fn(i)())
        else:
          # v2 ops are looped `iterations` times per session.run call.
          self.enqueue_ops.append(
              wrap_computation_in_while_loop(
                  get_enqueue_ops_fn(i), n=self.iterations,
                  host_name=get_host(self.resolver, i)))
      # Created once, after every host's input pipeline has been built, so a
      # single op covers all tables and the name is bound even with no hosts.
      init_tables = tf.tables_initializer()

    self.sess.run(init_tables)
    # Initialize dataset variables
    self.sess.run(self.dataset_initializer)

  def build_model(self, model_fn, params):
    """Builds the sharded TPU training loop and prepares the session.

    After building the graph this writes it to `out_dir` as "graph.pbtxt",
    runs the global and local variable initializers, restores the latest
    checkpoint found in `out_dir` (if any) and notifies the hooks that the
    session exists.

    Args:
      model_fn: Callable invoked as
        `model_fn(features, labels, mode, params)`; its result must expose
        `loss` and `train_op`.
      params: Forwarded unchanged to `model_fn`.
    """
    tf.logging.info("TrainLowLevelRunner: build_model method")

    def tpu_train_step(loss):
      """One training step: dequeue a batch, run the model, return its loss."""
      del loss  # The incoming loop value is not used by the step.
      dequeued = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
      batch = data_nest.pack_sequence_as(self.feature_structure, dequeued)
      spec = model_fn(batch["features"], batch["labels"],
                      tf.estimator.ModeKeys.TRAIN, params)
      step_loss = spec.loss
      update_op = spec.train_op
      # Returning the loss behind the train op forces the update to run.
      with tf.control_dependencies([update_op]):
        return tf.identity(step_loss)

    def train_loop():
      return tpu.repeat(self.iterations, tpu_train_step, [_INITIAL_LOSS])

    with self.graph.as_default():
      sharded_outputs = tpu.shard(
          train_loop,
          inputs=[],
          num_shards=self.hparams.num_shards,
          outputs_from_all_shards=False,
      )
      (self.loss,) = sharded_outputs
      init_globals = tf.global_variables_initializer()
      init_locals = tf.local_variables_initializer()
      graph_def = self.graph.as_graph_def(add_shapes=True)
      graph_io.write_graph(graph_def, self.hparams.out_dir, "graph.pbtxt")
      self.saver = tf.train.Saver()

    self.sess.run(init_globals)
    self.sess.run(init_locals)

    latest = tf.train.latest_checkpoint(self.hparams.out_dir)
    if latest:
      self.saver.restore(self.sess, latest)

    with self.graph.as_default():
      for hook in self.hooks:
        hook.after_create_session(self.sess, None)

  def get_global_step(self):
    """Evaluates and returns the global step tensor of this runner's graph."""
    with self.graph.as_default():
      global_step_tensor = tf.train.get_global_step()
      return self.sess.run(global_step_tensor)

  def train(self, start_step, train_steps, num_threads=2):
    """Run the Train loop on the TPU device.

    Runs the on-device loop `train_steps // iterations` times while a
    background thread feeds the infeed queues, saving a checkpoint after each
    loop on a small rotating pool of threads.

    Args:
      start_step: Offset added to the local step counter when naming
        checkpoints ("model.ckpt-<start_step + step>").
      train_steps: Total number of steps to run; must be a multiple of
        `self.iterations`.
      num_threads: Number of slots in the rotating pool of checkpoint
        threads. A slot's previous save is joined before it is reused.

    Raises:
      ValueError: If `train_steps` is not a multiple of `self.iterations`.
    """
    # Validate on the calling thread: failing inside the infeed thread would
    # only kill that thread and leave the training loop blocked on infeed.
    if train_steps % self.iterations != 0:
      raise ValueError(
          "train_steps (%d) must be a multiple of iterations (%d)" %
          (train_steps, self.iterations))

    tf.logging.info("TrainLowLevelRunner: train for %d steps in total",
                    train_steps)

    def infeed_thread_fn(sess, enqueue_ops):
      if self.per_host_v1:
        # v1 enqueue ops feed a single step per run.
        steps = train_steps
      else:
        # v2 enqueue ops already loop `iterations` times per run.
        steps = train_steps // self.iterations
      for _ in range(steps):
        sess.run([enqueue_ops])

    def checkpoint_thread_fn(saver, sess, step):
      # `step` is passed in rather than read from the enclosing scope so the
      # file name cannot be affected by the main loop advancing its counter.
      saver.save(sess,
                 self.hparams.out_dir + "/model.ckpt-%d" % (start_step + step))

    infeed_thread = threading.Thread(
        target=infeed_thread_fn, args=(self.sess, self.enqueue_ops))
    infeed_thread.start()

    cur_step = 0
    thread_id = 0
    checkpoint_threads = [None] * num_threads
    while cur_step < train_steps:
      start = time.time()
      tf.logging.info("TrainLowLevelRunner: start train step:%d", cur_step)
      cur_step += self.iterations
      loss = self.sess.run([self.loss])

      with self.graph.as_default():
        for hook in self.hooks:
          hook.after_run(tf.train.SessionRunContext([self.loss], self.sess), [loss])

      tf.logging.info("TrainLowLevelRunner: sess run loss: %s", loss)

      # Wait for the save previously started in this slot before reusing it.
      if checkpoint_threads[thread_id] is not None:
        checkpoint_threads[thread_id].join()
      checkpoint_threads[thread_id] = threading.Thread(
          target=checkpoint_thread_fn,
          args=(self.saver, self.sess, cur_step))
      checkpoint_threads[thread_id].start()
      thread_id += 1
      if thread_id >= num_threads:
        thread_id = 0

      end = time.time()
      tf.logging.info(
          "TrainLowLevelRunner: step time {} sec {} examples/sec".format(
              end - start,
              self.iterations * self.hparams.batch_size / (end - start)))

    infeed_thread.join()

    for pending in checkpoint_threads:
      if pending is not None:
        pending.join()

    with self.graph.as_default():
      for hook in self.hooks:
        hook.end(self.sess)


class EvalLowLevelRunner(object):
  """Run eval via direct session.run calls."""

  def __init__(self, eval_steps, hparams):
    """Creates the graph, opens a session and initializes the TPU system.

    Args:
      eval_steps: Number of steps run by the on-device eval loop and by the
        host-side infeed loop.
      hparams: Hyperparameters object. This constructor reads `master` and
        `tpu_name`.
    """
    tf.logging.info("EvalLowLevelRunner: constructor")
    tf.logging.info("eval_steps: %s", eval_steps)

    self.feature_structure = {}
    self.infeed_queue = []
    self.enqueue_ops = []
    self.dataset_initializer = []
    self.is_local = ((hparams.master == "") and (hparams.tpu_name is None))
    self.eval_steps = eval_steps
    self.sess = None
    self.eval_op = None
    self.graph = tf.Graph()
    self.hparams = hparams
    self.outfeed_tensors = []
    self.outfeed_names = []
    self.dequeue_ops = {}
    self.saver = None
    with self.graph.as_default():
      self.tpu_init = [tpu.initialize_system()]
      self.tpu_shutdown = tpu.shutdown_system()

    self.resolver = get_resolver(hparams)
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        operation_timeout_in_ms=600 * 60 * 1000)  # 10 hours

    # get_resolver() returns None whenever both `master` and `tpu_name` are
    # falsy, so an empty-string tpu_name must not reach the resolver branch
    # below (it would fail on `None.cluster_spec()`).
    if self.hparams.tpu_name is None or self.resolver is None:
      master = self.hparams.master
    else:
      cluster_spec = self.resolver.cluster_spec()
      if cluster_spec:
        session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
      master = self.resolver.get_master()

    self.sess = tf.Session(
        master,
        graph=self.graph,
        config=session_config)
    self.sess.run(self.tpu_init)

  def initialize(self, input_fn, params):
    """Builds the host-side infeed loop for eval and initializes tables.

    The dataset iterator initializer is only recorded here, not run.

    Args:
      input_fn: Callable invoked as `input_fn(params)`; must return a dataset.
      params: Forwarded unchanged to `input_fn`.
    """
    tf.logging.info("EvalLowLevelRunner: initialize method")

    def get_enqueue_ops_fn():
      """Creates the dataset iterator and returns the enqueue-op builder."""

      eval_dataset = input_fn(params)
      host_cpu = device_for_host(get_host(self.resolver))
      with tf.device(host_cpu):
        data_iter = eval_dataset.make_initializable_iterator()
        self.dataset_initializer.append(data_iter.initializer)

        def enqueue_ops_fn():
          """Builds the enqueue ops feeding every shard of the host."""
          shards = []
          earlier_tensors = []
          shard_count = self.hparams.num_shards_per_host
          for _ in range(shard_count):
            # Chain each read behind all previously read tensors so the
            # per-shard reads are issued in a fixed order.
            with tf.control_dependencies(earlier_tensors):
              batch = data_iter.get_next()
            self.feature_structure["features"] = batch
            flat_batch = data_nest.flatten(self.feature_structure)
            earlier_tensors.extend(flat_batch)
            shards.append(flat_batch)

          queue = tpu.InfeedQueue(number_of_tuple_elements=len(shards[0]))
          self.infeed_queue.append(queue)

          def ordinal_for_shard(shard_index_in_host):
            return shard_index_in_host % self.hparams.num_shards_per_host

          return queue.generate_enqueue_ops(
              shards, tpu_ordinal_function=ordinal_for_shard)

      return enqueue_ops_fn

    with self.graph.as_default():
      infeed_loop = wrap_computation_in_while_loop(
          get_enqueue_ops_fn(), n=self.eval_steps,
          host_name=get_host(self.resolver))
      self.enqueue_ops.append(infeed_loop)
      init_tables = tf.tables_initializer()

    self.sess.run(init_tables)

  def build_model(self, model_fn, params):
    """Builds the sharded TPU eval loop, outfeed dequeue ops and the saver.

    Populates `self.eval_op`, `self.dequeue_ops` (prediction name -> tensor
    concatenated over all shards along axis 0) and `self.saver`.

    Args:
      model_fn: Callable invoked as
        `model_fn(features, None, tf.estimator.ModeKeys.PREDICT, params)`;
        its result must expose a `predictions` dict.
      params: Forwarded unchanged to `model_fn`.
    """
    tf.logging.info("EvalLowLevelRunner: build_model method")

    def tpu_eval_step():
      """Generate the TPU graph."""
      values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
      unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                      values)
      features = unflattened_inputs["features"]
      estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT,
                                params)
      # Names and tensors are recorded in matching order so the host-side
      # dequeue ops can be keyed by prediction name afterwards.
      for k, v in six.iteritems(estimator_spec.predictions):
        self.outfeed_names.append(k)
        self.outfeed_tensors.append(v)

      with tf.device(device_for_tpu_core(get_host(self.resolver))):
        outfeed_enqueue_ops = tpu.outfeed_enqueue_tuple(self.outfeed_tensors)
      with tf.control_dependencies([outfeed_enqueue_ops]):
        return tf.no_op()

    def eval_loop():
      return tpu.repeat(self.eval_steps, tpu_eval_step, [])

    def create_dequeue_ops():
      """Create outfeed dequeue ops, one concatenated tensor per prediction."""
      tensor_dtypes = [v.dtype for v in self.outfeed_tensors]
      tensor_shapes = [v.shape for v in self.outfeed_tensors]
      # dequeue_ops[j] collects the j-th prediction tensor from every shard.
      dequeue_ops = [[] for _ in self.outfeed_tensors]
      # The host device does not depend on the shard, so compute it once.
      host_device = device_for_host(get_host(self.resolver))
      for i in range(self.hparams.num_shards):
        with tf.device(host_device):
          outfeed_tensors = tpu.outfeed_dequeue_tuple(
              dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
          for j, item in enumerate(outfeed_tensors):
            dequeue_ops[j].append(item)
      # Iterate the collected lists directly instead of sizing the loop from
      # a variable left over from the last shard iteration.
      return [tf.concat(per_shard, axis=0) for per_shard in dequeue_ops]

    with self.graph.as_default():
      (self.eval_op,) = tpu.shard(
          eval_loop,
          inputs=[],
          num_shards=self.hparams.num_shards,
          outputs_from_all_shards=False,
      )

      for name, dequeue_tensor in zip(self.outfeed_names,
                                      create_dequeue_ops()):
        self.dequeue_ops[name] = dequeue_tensor

      self.saver = tf.train.Saver()

  def predict(self, checkpoint_path=None):
    """Restores a checkpoint and yields predictions one example at a time.

    Args:
      checkpoint_path: Checkpoint to restore. When empty or None, the latest
        checkpoint in `hparams.out_dir` is used.

    Yields:
      For each eval step and each index below `hparams.infer_batch_size`, a
      dict mapping every prediction name to that index of its dequeued batch.
    """
    restore_from = checkpoint_path
    if not restore_from:
      restore_from = tf.train.latest_checkpoint(self.hparams.out_dir)

    self.saver.restore(self.sess, restore_from)
    # Initialize dataset variables
    self.sess.run(self.dataset_initializer)

    def run_fetch(session, fetch):
      session.run([fetch])

    # The infeed and the on-device eval loop each block in session.run, so
    # both run in background threads while this thread drains the outfeed.
    feeder = threading.Thread(
        target=run_fetch, args=(self.sess, self.enqueue_ops))
    feeder.start()

    evaluator = threading.Thread(
        target=run_fetch, args=(self.sess, self.eval_op))
    evaluator.start()

    for step in range(self.eval_steps):
      tf.logging.info("EvalLowLevelRunner: start eval step:%d", step)
      step_outputs = self.sess.run(self.dequeue_ops)
      for row in range(self.hparams.infer_batch_size):
        example = {}
        for name, batch in six.iteritems(step_outputs):
          example[name] = batch[row]
        yield example

    feeder.join()
    evaluator.join()
